import asyncio
import json
import math
import os
import re
from collections import defaultdict, deque

from src.relevance_filter import RelevanceFilter

class GraphBuilder:
    """Build an entity-relation graph by breadth-first expansion from a seed entity.

    Each expanded entity is sent to the API client for
    ``(subject ~ relation ~ object)`` triple extraction. The triples are passed
    through a ``RelevanceFilter``, and their objects form the next BFS frontier.

    Config keys read from ``cfg``:
        decay_factor (default 0.8): per-level shrink factor on how many
            frontier nodes are expanded; applied from depth 2 onward.
        max_depth (default 2): deepest BFS level that is expanded
            (levels 0..max_depth inclusive).
        batch_size (default 10): maximum number of extraction requests in
            flight at once; a non-positive value means no limit.
        base_dir (default "entity_graphs"): directory used by ``save``.
    """

    def __init__(self, cfg, api_client):
        self.cfg = cfg
        self.api_client = api_client
        self.relevance_filter = RelevanceFilter(cfg, api_client)
        self.decay_factor = cfg.get("decay_factor", 0.8)
        self.max_depth = cfg.get("max_depth", 2)
        self.batch_size = cfg.get("batch_size", 10)
        self.base_dir = cfg.get("base_dir", "entity_graphs")
        # Created eagerly so ``save`` can write without checking.
        os.makedirs(self.base_dir, exist_ok=True)

    async def build(self, initial_entity):
        """Expand ``initial_entity`` breadth-first and return the collected graph.

        Args:
            initial_entity: Seed entity name. It is lower-cased before the first
                extraction request. The original spelling is what the relevance
                filter receives.

        Returns:
            dict mapping each triple subject to a list of ``(relation, object)``
            tuples, in discovery order.
        """
        graph = defaultdict(list)
        # Lower-cased names of every entity already scheduled at some level
        # (expanded or pruned by decay). Keys are lower-cased because the seed
        # is, so "Paris" returned by the model matches the seed "paris".
        visited = set()
        frontier = [initial_entity.lower()]
        current_depth = 0

        while frontier and current_depth <= self.max_depth:
            level_nodes = self._dedupe_frontier(frontier, visited)
            if not level_nodes:
                break

            # Nodes cut by the decay limit below stay in ``visited`` so they
            # are not expanded at a deeper level either.
            visited.update(node.lower() for node in level_nodes)
            max_nodes = math.ceil(len(level_nodes) * (self.decay_factor ** max(current_depth - 1, 0)))
            level_nodes = level_nodes[:max_nodes]

            print(f"[GraphBuilder] Depth {current_depth}: expanding {len(level_nodes)} nodes")

            triples_per_node = await self._get_atomic_facts_batch(level_nodes)

            next_frontier = []
            for triples in triples_per_node:
                triples = await self.relevance_filter.filter_triples(triples, initial_entity)
                for subj, rel, obj in triples:
                    graph[subj].append((rel, obj))
                    if obj and obj.lower() not in visited:
                        next_frontier.append(obj)

            frontier = next_frontier
            current_depth += 1

        return dict(graph)

    @staticmethod
    def _dedupe_frontier(frontier, visited):
        """Return unvisited frontier names, de-duplicated case-insensitively (first spelling wins)."""
        seen = set()
        unique = []
        for node in frontier:
            key = node.lower()
            if key in visited or key in seen:
                continue
            seen.add(key)
            unique.append(node)
        return unique

    async def _get_atomic_facts_batch(self, entities):
        """Extract triples for every entity, with at most ``batch_size`` requests in flight.

        Returns:
            list of triple lists, aligned index-for-index with ``entities``.
        """
        entities = list(entities)
        if not entities:
            return []
        # Bound concurrency so a wide frontier does not flood the API client.
        if isinstance(self.batch_size, int) and self.batch_size > 0:
            step = self.batch_size
        else:
            step = len(entities)
        results = []
        for start in range(0, len(entities), step):
            chunk = entities[start:start + step]
            results.extend(await asyncio.gather(*(self._extract_atomic_facts(e) for e in chunk)))
        return results

    async def _extract_atomic_facts(self, entity):
        """Ask the API client for triples about ``entity`` and parse the reply.

        Returns:
            list of ``(subject, relation, object)`` string tuples.
        """
        # NOTE(review): the "text" handed to the model is only the stub
        # "<entity> is", so extracted facts come from the model's own knowledge
        # rather than a source document. Confirm this is intended.
        prompt = f"{entity} is"
        extraction_prompt = f"""
        Extract all (subject, relation, object) facts about "{entity}" from the following text:

        Text: "{prompt}"

        Format:
        (subject ~ relation ~ object)
        Only output the list.
        """
        response = await self.api_client.call("default", extraction_prompt, for_graph=True)
        return self._parse_response(response)

    def _parse_response(self, text):
        """Parse ``(subject ~ relation ~ object)`` groups out of a model response.

        Returns:
            list of stripped ``(subject, relation, object)`` tuples. Triples with
            an empty component are dropped; an empty/None response yields ``[]``.
        """
        if not text:
            return []
        # The subject may not contain parentheses or newlines. Otherwise a
        # preamble such as "Facts (from the text):\n(a ~ b ~ c)" gets glued
        # onto the first subject.
        pattern = r'\(([^~()\n]+)~([^~\n]+)~([^)\n]+)\)'
        triples = []
        for subj, rel, obj in re.findall(pattern, text):
            triple = (subj.strip(), rel.strip(), obj.strip())
            if all(triple):
                triples.append(triple)
        return triples

    def save(self, entity, graph):
        """Write ``graph`` as indented JSON to ``<base_dir>/<entity>.json``.

        Spaces in ``entity`` become underscores. Path separators and other
        characters invalid in file names also become underscores, so the file
        always lands directly inside ``base_dir``.
        """
        safe_name = re.sub(r'[\\/:*?"<>|\x00-\x1f]', "_", entity.replace(" ", "_"))
        save_path = os.path.join(self.base_dir, f"{safe_name}.json")
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(graph, f, indent=4)
        print(f"[GraphBuilder] Saved graph to {save_path}")
